I'm currently in the process of implementing a CGAN with convolutions and have written a discriminator, but I'm uncertain if my code is correct as the discriminator loss immediately drops to zero while the generator loss continues to increase. Could you kindly review my code for the discriminator?
# Define discriminator network
class Discriminator(nn.Module):
    """Conditional GAN discriminator for RGB images.

    The class label is embedded, broadcast to the spatial size of the
    input image, and concatenated channel-wise with the image before the
    convolutional stack.

    Args:
        num_classes: number of condition classes (also the embedding dim).
        img_size: height/width of the (square) input images. Must be
            divisible by 8 because of the three stride-2 convolutions.
            Defaults to 32, matching the original hard-coded ``256*4*4``
            flatten size.
    """

    def __init__(self, num_classes, img_size=32):
        super(Discriminator, self).__init__()
        if img_size % 8 != 0:
            raise ValueError("img_size must be divisible by 8")
        self.num_classes = num_classes
        self.img_size = img_size
        # One embedding vector per class; its dim (num_classes) becomes
        # the number of extra input channels.
        self.label_emb = nn.Embedding(num_classes, num_classes)
        # Input: 3 RGB channels + num_classes label channels.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3 + num_classes, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Three stride-2 convs downsample each spatial dim by a factor of 8.
        feat = img_size // 8
        self.fc = nn.Sequential(
            nn.Linear(256 * feat * feat, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        """Score a batch of images conditioned on their class labels.

        Args:
            img: float tensor of shape (batch, 3, img_size, img_size).
            labels: long tensor of shape (batch,) with class indices.

        Returns:
            Tensor of shape (batch, 1) with real/fake probabilities
            in [0, 1] (Sigmoid output — pair with BCELoss).
        """
        label_emb = self.label_emb(labels)  # (batch, num_classes)
        label_emb = label_emb.view(label_emb.size(0), label_emb.size(1), 1, 1)
        # Tile the embedding across the spatial dims so it can be
        # concatenated with the image channel-wise.
        label_emb = label_emb.expand(-1, -1, img.size(2), img.size(3))
        # NOTE(review): the original comment claimed 1 + num_classes
        # channels here; the image is RGB, so it is 3 + num_classes.
        dis_input = torch.cat((img, label_emb), dim=1)
        x = self.conv1(dis_input)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.shape[0], -1)  # flatten to (batch, 256 * feat * feat)
        return self.fc(x)
submitted by /u/odbhut_shei_chhele
[link] [comments]
(8 min)